import itertools

from data_handle.data_loader import create_data_loader
from config import Bert_Config



def preview_batches(loader, num_batches=11, keys=('input_ids', 'attention_mask', 'labels')):
    """Print the shapes of selected fields for the first few batches of *loader*.

    For each inspected batch, ``batch[key].shape`` is printed for every key in
    ``keys`` (in order), followed by a ``'------------------'`` separator line.

    Args:
        loader: Any iterable of batches. Each batch must support ``batch[key]``
            for every entry in ``keys``, and each value must expose ``.shape``
            (presumably torch tensors from ``create_data_loader`` — TODO confirm).
        num_batches: Maximum number of batches to inspect. The default of 11
            reproduces the original ``if step == 10: break`` loop, which printed
            steps 0 through 10 inclusive. Values <= 0 inspect nothing.
        keys: Batch keys whose shapes are printed.

    Returns:
        The number of batches actually printed. This is fewer than
        ``num_batches`` if the loader is exhausted first.
    """
    printed = 0
    # islice stops before requesting the next item, so no extra batch is
    # pulled from the loader beyond the ones we print (same as the old break).
    for batch in itertools.islice(loader, max(num_batches, 0)):
        for key in keys:
            print(batch[key].shape)
        print('------------------')
        printed += 1
    return printed


if __name__ == '__main__':
    # Smoke test for the training data pipeline: build the train loader from
    # the BERT config and eyeball the tensor shapes of the first few batches.
    bt_config = Bert_Config()
    train_path = bt_config.train_json_path
    batch_size = bt_config.batch_size

    train_loader = create_data_loader(train_path, batch_size)

    # Inspect the first 11 batches (steps 0-10), as the original loop did.
    preview_batches(train_loader, num_batches=11)
